{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NAS 基准测试示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint\n",
    "import time\n",
    "\n",
    "from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats\n",
    "from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats\n",
    "from nni.nas.benchmarks.nds import query_nds_trial_stats\n",
    "\n",
    "# Start timestamp; the final cell prints the time elapsed since this point.\n",
    "ti = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NAS-Bench-101"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下网络结构作为示例：\n",
    "\n",
    "![nas-101](../../img/nas-bench-101-example.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'config': {'arch': {'input1': [0],\n                     'input2': [1],\n                     'input3': [2],\n                     'input4': [0],\n                     'input5': [0, 3, 4],\n                     'input6': [2, 5],\n                     'op1': 'conv3x3-bn-relu',\n                     'op2': 'maxpool3x3',\n                     'op3': 'conv3x3-bn-relu',\n                     'op4': 'conv3x3-bn-relu',\n                     'op5': 'conv1x1-bn-relu'},\n            'hash': '00005c142e6f48ac74fdcf73e3439874',\n            'id': 4,\n            'num_epochs': 108,\n            'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n                    'id': 19,\n                    'test_acc': 77.40384340286255,\n                    'train_acc': 82.82251358032227,\n                    'training_time': 883.4580078125,\n                    'valid_acc': 77.76442170143127},\n                   {'current_epoch': 108,\n                    'id': 20,\n                    'test_acc': 92.11738705635071,\n                    'train_acc': 100.0,\n                    'training_time': 1769.1279296875,\n                    'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n                     'input2': [1],\n                     'input3': [2],\n                     'input4': [0],\n                     'input5': [0, 3, 4],\n                     'input6': [2, 5],\n                     'op1': 'conv3x3-bn-relu',\n                     'op2': 'maxpool3x3',\n                     'op3': 'conv3x3-bn-relu',\n                     'op4': 'conv3x3-bn-relu',\n                     'op5': 'conv1x1-bn-relu'},\n            'hash': '00005c142e6f48ac74fdcf73e3439874',\n            'id': 4,\n            'num_epochs': 108,\n            'num_vertices': 7},\n 'id': 11,\n 'intermediates': 
[{'current_epoch': 54,\n                    'id': 21,\n                    'test_acc': 82.04126358032227,\n                    'train_acc': 87.96073794364929,\n                    'training_time': 883.6810302734375,\n                    'valid_acc': 82.91265964508057},\n                   {'current_epoch': 108,\n                    'id': 22,\n                    'test_acc': 91.90705418586731,\n                    'train_acc': 100.0,\n                    'training_time': 1768.2509765625,\n                    'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n                     'input2': [1],\n                     'input3': [2],\n                     'input4': [0],\n                     'input5': [0, 3, 4],\n                     'input6': [2, 5],\n                     'op1': 'conv3x3-bn-relu',\n                     'op2': 'maxpool3x3',\n                     'op3': 'conv3x3-bn-relu',\n                     'op4': 'conv3x3-bn-relu',\n                     'op5': 'conv1x1-bn-relu'},\n            'hash': '00005c142e6f48ac74fdcf73e3439874',\n            'id': 4,\n            'num_epochs': 108,\n            'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n                    'id': 23,\n                    'test_acc': 80.58894276618958,\n                    'train_acc': 86.34815812110901,\n                    'training_time': 883.4569702148438,\n                    'valid_acc': 81.1598539352417},\n                   {'current_epoch': 108,\n                    'id': 24,\n                    'test_acc': 92.15745329856873,\n                    'train_acc': 100.0,\n                    'training_time': 1768.9759521484375,\n                    'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 
106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
    }
   ],
   "source": [
    "# The example architecture above, given as operator choices ('opN') and input lists ('inputN').\n",
    "arch = {\n",
    "    'op1': 'conv3x3-bn-relu',\n",
    "    'op2': 'maxpool3x3',\n",
    "    'op3': 'conv3x3-bn-relu',\n",
    "    'op4': 'conv3x3-bn-relu',\n",
    "    'op5': 'conv1x1-bn-relu',\n",
    "    'input1': [0],\n",
    "    'input2': [1],\n",
    "    'input3': [2],\n",
    "    'input4': [0],\n",
    "    'input5': [0, 3, 4],\n",
    "    'input6': [2, 5]\n",
    "}\n",
    "# include_intermediates=True attaches the intermediate evaluation records to every trial.\n",
    "trials = query_nb101_trial_stats(arch, 108, include_intermediates=True)\n",
    "for trial in trials:\n",
    "    pprint.pprint(trial)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同一个 NAS-Bench-101 网络结构可以被训练多次。生成器返回的每个元素都是一个字典，包含该 Trial 配置（网络结构 + 超参数）下的一次训练结果，如训练集/验证集/测试集准确率、训练时间、Epoch 数等。NAS-Bench-201 和 NDS 的结果遵循类似的格式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NAS-Bench-201"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下网络结构作为示例：\n",
    "\n",
    "![nas-201](../../img/nas-bench-201-example.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'config': {'arch': {'0_1': 'avg_pool_3x3',\n                     '0_2': 'conv_1x1',\n                     '0_3': 'conv_1x1',\n                     '1_2': 'skip_connect',\n                     '1_3': 'skip_connect',\n                     '2_3': 'skip_connect'},\n            'dataset': 'cifar100',\n            'id': 7,\n            'num_cells': 5,\n            'num_channels': 16,\n            'num_epochs': 200},\n 'flops': 15.65322,\n 'id': 3,\n 'latency': 0.013182918230692545,\n 'ori_test_acc': 53.11,\n 'ori_test_evaluation_time': 1.0195916947864352,\n 'ori_test_loss': 1.7307863704681397,\n 'parameters': 0.135156,\n 'seed': 999,\n 'test_acc': 53.07999995727539,\n 'test_evaluation_time': 0.5097958473932176,\n 'test_loss': 1.731276072692871,\n 'train_acc': 57.82,\n 'train_loss': 1.5116578379058838,\n 'training_time': 2888.4371995925903,\n 'valid_acc': 53.14000000610351,\n 'valid_evaluation_time': 0.5097958473932176,\n 'valid_loss': 1.7302966793060304}\n{'config': {'arch': {'0_1': 'avg_pool_3x3',\n                     '0_2': 'conv_1x1',\n                     '0_3': 'conv_1x1',\n                     '1_2': 'skip_connect',\n                     '1_3': 'skip_connect',\n                     '2_3': 'skip_connect'},\n            'dataset': 'cifar100',\n            'id': 7,\n            'num_cells': 5,\n            'num_channels': 16,\n            'num_epochs': 200},\n 'flops': 15.65322,\n 'id': 7,\n 'latency': 0.013182918230692545,\n 'ori_test_acc': 51.93,\n 'ori_test_evaluation_time': 1.0195916947864352,\n 'ori_test_loss': 1.7572312774658203,\n 'parameters': 0.135156,\n 'seed': 777,\n 'test_acc': 51.979999938964845,\n 'test_evaluation_time': 0.5097958473932176,\n 'test_loss': 1.7429540189743042,\n 'train_acc': 57.578,\n 'train_loss': 1.5114233912658692,\n 'training_time': 2888.4371995925903,\n 'valid_acc': 51.88,\n 'valid_evaluation_time': 0.5097958473932176,\n 'valid_loss': 1.7715086591720581}\n{'config': {'arch': {'0_1': 'avg_pool_3x3',\n                    
 '0_2': 'conv_1x1',\n                     '0_3': 'conv_1x1',\n                     '1_2': 'skip_connect',\n                     '1_3': 'skip_connect',\n                     '2_3': 'skip_connect'},\n            'dataset': 'cifar100',\n            'id': 7,\n            'num_cells': 5,\n            'num_channels': 16,\n            'num_epochs': 200},\n 'flops': 15.65322,\n 'id': 11,\n 'latency': 0.013182918230692545,\n 'ori_test_acc': 53.38,\n 'ori_test_evaluation_time': 1.0195916947864352,\n 'ori_test_loss': 1.7281623031616211,\n 'parameters': 0.135156,\n 'seed': 888,\n 'test_acc': 53.67999998779297,\n 'test_evaluation_time': 0.5097958473932176,\n 'test_loss': 1.7327697801589965,\n 'train_acc': 57.792,\n 'train_loss': 1.5091403088760376,\n 'training_time': 2888.4371995925903,\n 'valid_acc': 53.08000000610352,\n 'valid_evaluation_time': 0.5097958473932176,\n 'valid_loss': 1.7235548280715942}\n"
    }
   ],
   "source": [
    "# The example cell above: one operation per 'i_j' key.\n",
    "arch = {\n",
    "    '0_1': 'avg_pool_3x3',\n",
    "    '0_2': 'conv_1x1',\n",
    "    '1_2': 'skip_connect',\n",
    "    '0_3': 'conv_1x1',\n",
    "    '1_3': 'skip_connect',\n",
    "    '2_3': 'skip_connect'\n",
    "}\n",
    "trials = query_nb201_trial_stats(arch, 200, 'cifar100')\n",
    "for trial in trials:\n",
    "    pprint.pprint(trial)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过设置 `include_intermediates=True`，还可以获取训练过程中的中间结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n"
    }
   ],
   "source": [
    "# None for the epoch count matches trials of any length (both 12- and 200-epoch runs appear below).\n",
    "for trial in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
    "    print(trial['config'])\n",
    "    intermediates = trial['intermediates']\n",
    "    print('Intermediates:', len(intermediates))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NDS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下网络结构作为示例：\n",
    "\n",
    "![nds](../../img/nas-bench-nds-example.png)\n",
    "\n",
    "其中，`bot_muls`、`ds`、`num_gs`、`ss` 和 `ws` 分别代表 \"bottleneck multipliers\"（瓶颈倍数）、\"depths\"（深度）、\"number of groups\"（分组数）、\"strides\"（步长）和 \"widths\"（宽度）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'best_test_acc': 90.48,\n 'best_train_acc': 96.356,\n 'best_train_loss': 0.116,\n 'config': {'base_lr': 0.1,\n            'cell_spec': {},\n            'dataset': 'cifar10',\n            'generator': 'random',\n            'id': 45505,\n            'model_family': 'residual_bottleneck',\n            'model_spec': {'bot_muls': [0.0, 0.25, 0.25, 0.25],\n                           'ds': [1, 16, 1, 4],\n                           'num_gs': [1, 2, 1, 2],\n                           'ss': [1, 1, 2, 2],\n                           'ws': [16, 64, 128, 16]},\n            'num_epochs': 100,\n            'proposer': 'resnext-a',\n            'weight_decay': 0.0005},\n 'final_test_acc': 90.39,\n 'final_train_acc': 96.298,\n 'final_train_loss': 0.116,\n 'flops': 69.890986,\n 'id': 45505,\n 'iter_time': 0.065,\n 'parameters': 0.083002,\n 'seed': 1}\n"
    }
   ],
   "source": [
    "model_spec = {\n",
    "    'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
    "    'ds': [1, 16, 1, 4],\n",
    "    'num_gs': [1, 2, 1, 2],\n",
    "    'ss': [1, 1, 2, 2],\n",
    "    'ws': [16, 64, 128, 16]\n",
    "}\n",
    "# Use None as a wildcard: proposer, generator and cell_spec are left unconstrained here.\n",
    "for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "[{'current_epoch': 1,\n  'id': 4494501,\n  'test_acc': 41.76,\n  'train_acc': 30.421000000000006,\n  'train_loss': 1.793},\n {'current_epoch': 2,\n  'id': 4494502,\n  'test_acc': 54.66,\n  'train_acc': 47.24,\n  'train_loss': 1.415},\n {'current_epoch': 3,\n  'id': 4494503,\n  'test_acc': 59.97,\n  'train_acc': 56.983,\n  'train_loss': 1.179},\n {'current_epoch': 4,\n  'id': 4494504,\n  'test_acc': 62.91,\n  'train_acc': 61.955,\n  'train_loss': 1.048},\n {'current_epoch': 5,\n  'id': 4494505,\n  'test_acc': 66.16,\n  'train_acc': 64.493,\n  'train_loss': 0.983},\n {'current_epoch': 6,\n  'id': 4494506,\n  'test_acc': 66.5,\n  'train_acc': 66.274,\n  'train_loss': 0.937},\n {'current_epoch': 7,\n  'id': 4494507,\n  'test_acc': 67.55,\n  'train_acc': 67.426,\n  'train_loss': 0.907},\n {'current_epoch': 8,\n  'id': 4494508,\n  'test_acc': 69.45,\n  'train_acc': 68.45400000000001,\n  'train_loss': 0.878},\n {'current_epoch': 9,\n  'id': 4494509,\n  'test_acc': 70.14,\n  'train_acc': 69.295,\n  'train_loss': 0.857},\n {'current_epoch': 10,\n  'id': 4494510,\n  'test_acc': 69.47,\n  'train_acc': 70.304,\n  'train_loss': 0.832}]\n"
    }
   ],
   "source": [
    "model_spec = dict(\n",
    "    bot_muls=[0.0, 0.25, 0.25, 0.25],\n",
    "    ds=[1, 16, 1, 4],\n",
    "    num_gs=[1, 2, 1, 2],\n",
    "    ss=[1, 1, 2, 2],\n",
    "    ws=[16, 64, 128, 16],\n",
    ")\n",
    "# Same query as the previous cell, but also fetch intermediates and show the first 10 records.\n",
    "trials = query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True)\n",
    "for trial in trials:\n",
    "    pprint.pprint(trial['intermediates'][:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'best_test_acc': 93.58,\n 'best_train_acc': 99.772,\n 'best_train_loss': 0.011,\n 'config': {'base_lr': 0.1,\n            'cell_spec': {},\n            'dataset': 'cifar10',\n            'generator': 'random',\n            'id': 108998,\n            'model_family': 'residual_basic',\n            'model_spec': {'ds': [1, 12, 12, 12],\n                           'ss': [1, 1, 2, 2],\n                           'ws': [16, 24, 24, 40]},\n            'num_epochs': 100,\n            'proposer': 'resnet',\n            'weight_decay': 0.0005},\n 'final_test_acc': 93.49,\n 'final_train_acc': 99.772,\n 'final_train_loss': 0.011,\n 'flops': 184.519578,\n 'id': 108998,\n 'iter_time': 0.059,\n 'parameters': 0.594138,\n 'seed': 1}\n"
    }
   ],
   "source": [
    "model_spec = {\n",
    "    'ds': [1, 12, 12, 12],\n",
    "    'ss': [1, 1, 2, 2],\n",
    "    'ws': [16, 24, 24, 40]\n",
    "}\n",
    "# Every filter is pinned here, including an empty cell_spec.\n",
    "for trial in query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10'):\n",
    "    pprint.pprint(trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'best_test_acc': 84.5,\n 'best_train_acc': 89.66499999999999,\n 'best_train_loss': 0.302,\n 'config': {'base_lr': 0.1,\n            'cell_spec': {},\n            'dataset': 'cifar10',\n            'generator': 'random',\n            'id': 139492,\n            'model_family': 'vanilla',\n            'model_spec': {'ds': [1, 12, 12, 12],\n                           'ss': [1, 1, 2, 2],\n                           'ws': [16, 24, 32, 40]},\n            'num_epochs': 100,\n            'proposer': 'vanilla',\n            'weight_decay': 0.0005},\n 'final_test_acc': 84.35,\n 'final_train_acc': 89.633,\n 'final_train_loss': 0.303,\n 'flops': 208.36393,\n 'id': 154692,\n 'iter_time': 0.058,\n 'parameters': 0.68977,\n 'seed': 1}\n"
    }
   ],
   "source": [
    "# get the first one: next() pulls a single matching trial without iterating over the rest\n",
    "first_trial = next(query_nds_trial_stats('vanilla', None, None, None, None, None))\n",
    "pprint.pprint(first_trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "{'best_test_acc': 93.37,\n 'best_train_acc': 99.91,\n 'best_train_loss': 0.006,\n 'config': {'base_lr': 0.1,\n            'cell_spec': {'normal_0_input_x': 0,\n                          'normal_0_input_y': 1,\n                          'normal_0_op_x': 'avg_pool_3x3',\n                          'normal_0_op_y': 'conv_7x1_1x7',\n                          'normal_1_input_x': 2,\n                          'normal_1_input_y': 0,\n                          'normal_1_op_x': 'sep_conv_3x3',\n                          'normal_1_op_y': 'sep_conv_5x5',\n                          'normal_2_input_x': 2,\n                          'normal_2_input_y': 2,\n                          'normal_2_op_x': 'dil_sep_conv_3x3',\n                          'normal_2_op_y': 'dil_sep_conv_3x3',\n                          'normal_3_input_x': 4,\n                          'normal_3_input_y': 4,\n                          'normal_3_op_x': 'skip_connect',\n                          'normal_3_op_y': 'dil_sep_conv_3x3',\n                          'normal_4_input_x': 2,\n                          'normal_4_input_y': 4,\n                          'normal_4_op_x': 'conv_7x1_1x7',\n                          'normal_4_op_y': 'sep_conv_3x3',\n                          'normal_concat': [3, 5, 6],\n                          'reduce_0_input_x': 0,\n                          'reduce_0_input_y': 1,\n                          'reduce_0_op_x': 'avg_pool_3x3',\n                          'reduce_0_op_y': 'dil_sep_conv_3x3',\n                          'reduce_1_input_x': 0,\n                          'reduce_1_input_y': 0,\n                          'reduce_1_op_x': 'sep_conv_3x3',\n                          'reduce_1_op_y': 'sep_conv_3x3',\n                          'reduce_2_input_x': 2,\n                          'reduce_2_input_y': 0,\n                          'reduce_2_op_x': 'skip_connect',\n                          'reduce_2_op_y': 'sep_conv_7x7',\n                          
'reduce_3_input_x': 4,\n                          'reduce_3_input_y': 4,\n                          'reduce_3_op_x': 'conv_7x1_1x7',\n                          'reduce_3_op_y': 'skip_connect',\n                          'reduce_4_input_x': 0,\n                          'reduce_4_input_y': 5,\n                          'reduce_4_op_x': 'conv_7x1_1x7',\n                          'reduce_4_op_y': 'conv_7x1_1x7',\n                          'reduce_concat': [3, 6]},\n            'dataset': 'cifar10',\n            'generator': 'random',\n            'id': 1,\n            'model_family': 'nas_cell',\n            'model_spec': {'aux': False,\n                           'depth': 12,\n                           'drop_prob': 0.0,\n                           'num_nodes_normal': 5,\n                           'num_nodes_reduce': 5,\n                           'width': 32},\n            'num_epochs': 100,\n            'proposer': 'amoeba',\n            'weight_decay': 0.0005},\n 'final_test_acc': 93.27,\n 'final_train_acc': 99.91,\n 'final_train_loss': 0.006,\n 'flops': 664.400586,\n 'id': 1,\n 'iter_time': 0.281,\n 'parameters': 4.190314,\n 'seed': 1}\n"
    }
   ],
   "source": [
    "# query one specific NAS cell architecture and check that the returned configs match it\n",
    "model_spec = {'num_nodes_normal': 5, 'num_nodes_reduce': 5, 'depth': 12, 'width': 32, 'aux': False, 'drop_prob': 0.0}\n",
    "cell_spec = {\n",
    "    'normal_0_op_x': 'avg_pool_3x3',\n",
    "    'normal_0_input_x': 0,\n",
    "    'normal_0_op_y': 'conv_7x1_1x7',\n",
    "    'normal_0_input_y': 1,\n",
    "    'normal_1_op_x': 'sep_conv_3x3',\n",
    "    'normal_1_input_x': 2,\n",
    "    'normal_1_op_y': 'sep_conv_5x5',\n",
    "    'normal_1_input_y': 0,\n",
    "    'normal_2_op_x': 'dil_sep_conv_3x3',\n",
    "    'normal_2_input_x': 2,\n",
    "    'normal_2_op_y': 'dil_sep_conv_3x3',\n",
    "    'normal_2_input_y': 2,\n",
    "    'normal_3_op_x': 'skip_connect',\n",
    "    'normal_3_input_x': 4,\n",
    "    'normal_3_op_y': 'dil_sep_conv_3x3',\n",
    "    'normal_3_input_y': 4,\n",
    "    'normal_4_op_x': 'conv_7x1_1x7',\n",
    "    'normal_4_input_x': 2,\n",
    "    'normal_4_op_y': 'sep_conv_3x3',\n",
    "    'normal_4_input_y': 4,\n",
    "    'normal_concat': [3, 5, 6],\n",
    "    'reduce_0_op_x': 'avg_pool_3x3',\n",
    "    'reduce_0_input_x': 0,\n",
    "    'reduce_0_op_y': 'dil_sep_conv_3x3',\n",
    "    'reduce_0_input_y': 1,\n",
    "    'reduce_1_op_x': 'sep_conv_3x3',\n",
    "    'reduce_1_input_x': 0,\n",
    "    'reduce_1_op_y': 'sep_conv_3x3',\n",
    "    'reduce_1_input_y': 0,\n",
    "    'reduce_2_op_x': 'skip_connect',\n",
    "    'reduce_2_input_x': 2,\n",
    "    'reduce_2_op_y': 'sep_conv_7x7',\n",
    "    'reduce_2_input_y': 0,\n",
    "    'reduce_3_op_x': 'conv_7x1_1x7',\n",
    "    'reduce_3_input_x': 4,\n",
    "    'reduce_3_op_y': 'skip_connect',\n",
    "    'reduce_3_input_y': 4,\n",
    "    'reduce_4_op_x': 'conv_7x1_1x7',\n",
    "    'reduce_4_input_x': 0,\n",
    "    'reduce_4_op_y': 'conv_7x1_1x7',\n",
    "    'reduce_4_input_y': 5,\n",
    "    'reduce_concat': [3, 6]\n",
    "}\n",
    "\n",
    "for t in query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10'):\n",
    "    # The query filters on both specs, so every returned trial must echo them back exactly.\n",
    "    assert t['config']['model_spec'] == model_spec\n",
    "    assert t['config']['cell_spec'] == cell_spec\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "NDS (amoeba) count: 5107\n"
    }
   ],
   "source": [
    "# count number of trials proposed by 'amoeba' (every other filter is a None wildcard)\n",
    "amoeba_trials = query_nds_trial_stats(None, 'amoeba', None, None, None, None, None)\n",
    "print('NDS (amoeba) count:', sum(1 for _ in amoeba_trials))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Elapsed time:  2.2023813724517822 seconds\n"
    }
   ],
   "source": [
    "# Total wall-clock time since the first code cell recorded `ti`; print() already inserts the separating spaces.\n",
    "print('Elapsed time:', time.time() - ti, 'seconds')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "version": "3.6.10-final"
  },
  "orig_nbformat": 2,
  "file_extension": ".py",
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "pygments_lexer": "ipython3",
  "version": 3,
  "kernelspec": {
   "name": "python361064bitnnilatestcondabff8d66a619a4d26af34fe0fe687c7b0",
   "display_name": "Python 3.6.10 64-bit ('nnilatest': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}